import os
import pickle
import time

import numpy as np

import util
import z34alg
from k_median_alg import K_median_alg
from kmeans_l2 import Kmeans_l2
from result_collection import Result_collection, Run_results




def run_experiment(dataset, subset_sizes, n_runs, algorithm, k, res_collection):
    """Evaluate one configured clustering ``algorithm`` on random row subsets.

    The algorithm is first run on the whole ``dataset`` (timed and printed).
    Then, for each entry of ``subset_sizes``, a ``Run_results(k, size, n_runs)``
    is filled with:
      * the full-dataset cost and the full-dataset centers,
      * ``n_runs`` repetitions of: draw ``size`` distinct rows uniformly at
        random (``np.random.choice`` without replacement), call
        ``algorithm.get_subset_and_original(dataset, subset)``, and store the
        second element of its returned pair plus ``get_centers_subset()``.

    Each filled ``Run_results`` is appended to ``res_collection`` via
    ``add_result``.

    Args:
        dataset: 2-D indexable array; rows are sampled along axis 0
            (presumably one row per data point).
        subset_sizes: Iterable of subset sizes; each must be <= the number of rows.
        n_runs: Number of random subsets drawn per subset size.
        algorithm: Object providing ``get_original_cost``, ``get_centers_full``,
            ``get_subset_and_original`` and ``get_centers_subset``.
        k: Number of clusters, recorded in every ``Run_results``.
        res_collection: Collection that receives one result per subset size.

    Returns:
        The same ``res_collection`` object, after mutation.
    """
    print("getting original cost")
    t0 = time.time()
    full_cost = algorithm.get_original_cost(dataset)
    print(f"time to get original cost: {time.time() - t0}")

    for size in subset_sizes:
        results = Run_results(k, size, n_runs)
        results.set_originalcost(full_cost)
        results.set_centers_full(algorithm.get_centers_full())

        for run_idx in range(n_runs):
            print(f"subset size: {size}, k: {k}, run: {run_idx}")
            row_ids = np.random.choice(dataset.shape[0], size, replace=False)
            sample = dataset[row_ids, :]
            # The first element (cost on the subset itself) is not recorded.
            # The second is presumably the subset solution's cost evaluated on
            # the full dataset (TODO confirm against the algorithm classes).
            _subset_cost, solution_on_full = algorithm.get_subset_and_original(dataset, sample)
            results.add_subset_solution_original(solution_on_full)
            results.add_centers_subset(algorithm.get_centers_subset())

        res_collection.add_result(results)
    return res_collection

def get_datasets():
    """Return the result of ``util.load_normalized_datasets()`` called with no arguments."""
    return util.load_normalized_datasets()

def run_kmeans_experiment(results_dir="results3"):
    """Run the k-means subset experiment on several datasets and pickle the results.

    The datasets loaded are MUSHROOM, SKIN_NON_SKIN, COVTYPE and MNIST. Each
    loaded entry is read as ``dataset[0]`` (the data) and ``dataset[1]`` (its
    name). For each dataset:
      * ``run_experiments_kmeans`` is called over all ``k_values`` and
        ``subset_sizes`` with ``n_init=1``;
      * the result is pickled to
        ``<results_dir>/results_<util.Algorithm.KMEANS.value>_<name>.pkl``;
      * the result is printed.

    Args:
        results_dir: Output directory for the pickle files. It is created if
            missing. Defaults to ``"results3"``, the previously hard-coded location.
    """
    datasets = util.load_normalized_datasets(
        [util.Dataset.MUSHROOM, util.Dataset.SKIN_NON_SKIN, util.Dataset.COVTYPE, util.Dataset.MNIST]
    )
    k_values = [10, 20, 30, 50]
    n_runs = 5

    # These values do not depend on the dataset, so compute them once.
    # 2**min_subset_size is the smallest power of two strictly greater than
    # max(k_values), so every subset has more points than there are centers.
    # Sizes then double up to 2**12 (the range end, 13, is exclusive).
    min_subset_size = int(np.log2(np.max(k_values))) + 1
    max_subset_size = 13
    subset_sizes = [2**i for i in range(min_subset_size, max_subset_size)]

    # Create the output folder before the expensive runs. Otherwise a missing
    # directory makes open() fail only after all experiments have finished.
    os.makedirs(results_dir, exist_ok=True)

    for dataset in datasets:
        data, name = dataset[0], dataset[1]
        print("dataset: ", name)

        res = run_experiments_kmeans(data, name, k_values, subset_sizes, n_runs, n_init=1)

        out_path = os.path.join(results_dir, f"results_{util.Algorithm.KMEANS.value}_{name}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(res, f)

        print(res)

def run_experiments_kmeans(dataset, dataset_name, k_values, subset_sizes, n_runs, n_init=3):
    """Run ``run_experiment`` for every k with each k-means implementation.

    For each algorithm class in ``algorithms`` (currently only ``Kmeans_l2``):
      * one ``Result_collection`` is created, named after the class;
      * for every ``k`` in ``k_values``, a fresh ``algorithm_cls(k, n_init)``
        is run through ``run_experiment``, accumulating into that collection.

    Args:
        dataset: Data passed unchanged to ``run_experiment``.
        dataset_name: Name stored in the ``Result_collection``.
        k_values: Cluster counts to evaluate.
        subset_sizes: Subset sizes passed to ``run_experiment``.
        n_runs: Random subsets drawn per subset size.
        n_init: Second constructor argument of each algorithm class.

    Returns:
        The ``Result_collection`` of the last class in ``algorithms`` (the
        only one today). If ``algorithms`` were empty, an empty dict.
    """
    algorithms = [Kmeans_l2]
    algorithm_res = dict()
    for algorithm_cls in algorithms:
        # Take the name from the class itself. Using ``__class__`` on a class
        # object yields its metaclass, so the name would always be "type".
        res_collection = Result_collection(
            k_values, subset_sizes, n_runs, dataset_name, algorithm_cls.__name__
        )
        for k in k_values:
            # Instantiate the class currently being iterated, not a fixed entry.
            algorithm = algorithm_cls(k, n_init)
            res_collection = run_experiment(dataset, subset_sizes, n_runs, algorithm, k, res_collection)
        algorithm_res = res_collection
    return algorithm_res

def run_experiments_kmedian(results_dir="results3"):
    """Run the k-median subset experiment and pickle one collection per dataset.

    Only the SKIN_NON_SKIN dataset is loaded. Each loaded entry is read as
    ``dataset[0]`` (the data) and ``dataset[1]`` (its name). For each dataset:
      * a ``Result_collection`` tagged with ``util.Algorithm.KMEDIAN.value``
        is filled by running ``K_median_alg(k, max_iter, n_init)`` through
        ``run_experiment`` for every k;
      * the collection is pickled to
        ``<results_dir>/results_<KMEDIAN value>_<name>.pkl``.

    Args:
        results_dir: Output directory for the pickle files. It is created if
            missing. Defaults to ``"results3"``, the previously hard-coded location.
    """
    datasets = util.load_normalized_datasets([util.Dataset.SKIN_NON_SKIN])
    k_values = [10, 20, 30, 50]
    max_iter = 100
    n_init = 3
    n_runs = 5

    # These values do not depend on the dataset, so compute them once.
    # 2**min_subset_size is the smallest power of two strictly greater than
    # max(k_values). Sizes then double up to 2**12 (13 is exclusive).
    min_subset_size = int(np.log2(np.max(k_values))) + 1
    max_subset_size = 13
    subset_sizes = [2**i for i in range(min_subset_size, max_subset_size)]

    # Create the output folder before the expensive runs. Otherwise a missing
    # directory makes open() fail only after all experiments have finished.
    os.makedirs(results_dir, exist_ok=True)

    for dataset in datasets:
        dataset_name = dataset[1]
        print("dataset: ", dataset_name)
        res_collection = Result_collection(
            k_values, subset_sizes, n_runs, dataset_name, util.Algorithm.KMEDIAN.value
        )
        for k in k_values:
            algorithm = K_median_alg(k, max_iter, n_init)
            res_collection = run_experiment(dataset[0], subset_sizes, n_runs, algorithm, k, res_collection)

        out_path = os.path.join(results_dir, f"results_{util.Algorithm.KMEDIAN.value}_{dataset_name}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(res_collection, f)


def run_experiments_z34(results_dir="results3"):
    """Run the Z34 subset experiment for z in {3, 4} and pickle the results.

    Only the MNIST dataset is loaded. Each loaded entry is read as
    ``dataset[0]`` (the data) and ``dataset[1]`` (its name). For each z and
    each dataset:
      * a ``Result_collection`` tagged with ``util.Algorithm.Z34.value`` and
        ``z`` is filled by running
        ``z34alg.Z34(k, max_iter, max_sgd, lr, z)`` through ``run_experiment``
        for every k;
      * the collection is pickled to
        ``<results_dir>/results_<Z34 value>_<name>_z<z>.pkl``.

    Args:
        results_dir: Output directory for the pickle files. It is created if
            missing. Defaults to ``"results3"``, the previously hard-coded location.
    """
    datasets = util.load_normalized_datasets([util.Dataset.MNIST])
    k_values = [10, 20, 30, 50]
    max_iter = 100
    max_sgd = 50  # presumably an SGD step budget inside Z34 -- TODO confirm
    n_runs = 5
    lr = 0.001

    # These values depend on neither z nor the dataset, so compute them once.
    # 2**min_subset_size is the smallest power of two strictly greater than
    # max(k_values). Sizes then double up to 2**12 (13 is exclusive).
    min_subset_size = int(np.log2(np.max(k_values))) + 1
    max_subset_size = 13
    subset_sizes = [2**i for i in range(min_subset_size, max_subset_size)]

    # Create the output folder before the expensive runs. Otherwise a missing
    # directory makes open() fail only after all experiments have finished.
    os.makedirs(results_dir, exist_ok=True)

    for z in [3, 4]:
        for dataset in datasets:
            dataset_name = dataset[1]
            print("dataset: ", dataset_name)
            res_collection = Result_collection(
                k_values, subset_sizes, n_runs, dataset_name, util.Algorithm.Z34.value, z=z
            )
            for k in k_values:
                algorithm = z34alg.Z34(k, max_iter, max_sgd, lr, z)
                res_collection = run_experiment(dataset[0], subset_sizes, n_runs, algorithm, k, res_collection)

            out_path = os.path.join(
                results_dir, f"results_{util.Algorithm.Z34.value}_{dataset_name}_z{z}.pkl"
            )
            with open(out_path, "wb") as f:
                pickle.dump(res_collection, f)
# Script entry point: run the enabled experiment(s) and report total wall time.
if __name__ == "__main__":
    t_begin = time.time()
    print("Starting experiment")
    run_experiments_kmedian()
    print("starting kmeans")
    # Uncomment to also run these experiments:
    # run_kmeans_experiment()
    # run_experiments_z34()
    print(f"total time: {time.time() - t_begin}")
    